package com.diven.spark.ml.learn.pipeline

import com.diven.spark.ml.learn.core.{BaseSpark, BaseTest}
import com.diven.spark.ml.learn.feature.IndexToStringTest
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{ColumnFilter, IndexToString, StringIndexer, StringIndexerModel, VectorAssembler, VectorIndexer}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.tree.InternalNode


/** Companion factory for the decision-tree classification test task. */
object DecisionTreeTest extends BaseTest {

  /** @return a freshly constructed [[DecisionTreeTest]] task, exposed as a [[BaseSpark]]. */
  def apply(): BaseSpark = new DecisionTreeTest
}

/**
 * Trains and evaluates a decision-tree classifier on the iris data set with a Spark ML [[Pipeline]].
 *
 * The pipeline:
 *  1. keeps only the four measurement columns plus `Species`;
 *  2. assembles the measurements into a `features` vector;
 *  3. indexes the label and the features;
 *  4. fits a [[DecisionTreeClassifier]];
 *  5. maps numeric predictions back to species names in `predictedSpecies`.
 */
class DecisionTreeTest extends BaseSpark {

  /** @return the location of the iris data set (`irisPath`). */
  override def getDataPath(): String = {
    this.irisPath
  }

  /**
   * Runs the task.
   *
   * Splits the data 70/30 into training and test sets, fits the pipeline on the training split,
   * predicts on the test split and prints the test error (1 - accuracy).
   *
   * @param dataFrame the loaded iris data; expected to contain the columns `SepalLength`,
   *                  `SepalWidth`, `PetalLength`, `PetalWidth` and `Species`
   */
  override def execute(dataFrame: DataFrame): Unit = {
    // Split into training (70%) and test (30%) sets. No seed is given, so the split varies per run.
    val Array(trainingData, testData) = dataFrame.randomSplit(Array(0.7, 0.3))

    val featureColumns = Array("SepalLength", "SepalWidth", "PetalLength", "PetalWidth")

    // Keep only the measurement columns and the label column.
    val columnFilter = new ColumnFilter().setFilterColumns(featureColumns :+ "Species")

    // Assemble the measurement columns into a single vector column.
    val vectorAssembler = new VectorAssembler()
      .setInputCols(featureColumns)
      .setOutputCol("features")

    // Index the label column. It is fitted on the full data set (not only the training split)
    // so every species present in the data gets an index. Its labels are also needed up front
    // by labelConverter below.
    val labelIndexer = new StringIndexer()
      .setInputCol("Species")
      .setOutputCol("indexedLabel")
      .fit(dataFrame)

    // Index the feature vector: features with > 4 distinct values are treated as continuous,
    // the others as categorical.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)

    // The decision-tree estimator trained on the indexed label and features.
    val dt = new DecisionTreeClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")

    // Convert predicted label indices back to the original species names.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedSpecies")
      .setLabels(labelIndexer.labels)

    // Chain the filter, assembler, indexers, tree and converter in a single Pipeline.
    val pipeline = new Pipeline()
      .setStages(Array(columnFilter, vectorAssembler, labelIndexer, featureIndexer, dt, labelConverter))

    // Train on the training split, then predict on the held-out test split.
    val model = pipeline.fit(trainingData)
    val predictions = model.transform(testData)

    // MulticlassClassificationEvaluator defaults to the "f1" metric. Request "accuracy"
    // explicitly so that (1 - result) really is the test error that is printed.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    val accuracy = evaluator.evaluate(predictions)
    println("Test Error = " + (1.0 - accuracy))

    // To inspect the learned tree, look it up by type rather than by stage index
    // (it is stage 4 of the fitted pipeline):
    //   model.stages.collectFirst { case m: DecisionTreeClassificationModel => m }
    //     .foreach(m => println("Learned classification tree model:\n" + m.toDebugString))
  }

}
